{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66178523",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Execute this cell to install dependencies.\n",
    "# The extras spec is quoted so shells that glob square brackets (e.g. zsh) pass it to pip unchanged.\n",
    "%pip install \"sf-hamilton[visualization]\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "97f4cc58",
   "metadata": {},
   "source": [
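    "# Materialization [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/apache/hamilton/blob/main/examples/materialization/notebook.ipynb) [![GitHub badge](https://img.shields.io/badge/github-view_source-2b3137?logo=github)](https://github.com/apache/hamilton/blob/main/examples/materialization/notebook.ipynb)\n",
    "\n",
    "This notebook shows how to use Hamilton materializers to save DAG outputs (model parameters, a classification report, the fitted model, and predictions) to files. It visualizes them with `dr.visualize_materialization` and runs them with `dr.materialize`.\n"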
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6e2fc99a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-09-04T17:30:02.674284Z",
     "start_time": "2023-09-04T17:29:59.371085Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/stefankrawczyk/.pyenv/versions/knowledge_retrieval-py39/lib/python3.9/site-packages/pyspark/pandas/__init__.py:50: UserWarning: 'PYARROW_IGNORE_TIMEZONE' environment variable was not set. It is required to set this environment variable to '1' in both driver and executor sides if you use pyarrow>=2.0.0. pandas-on-Spark will set it for you but it does not work if there is a Spark context already launched.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "import os\n",
    "\n",
    "import data_loaders\n",
    "import model_training\n",
    "\n",
    "from hamilton import base, driver\n",
    "from hamilton.io.materialization import to\n",
    "import pandas as pd\n",
    "\n",
    "import custom_materializers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "222756ff",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-09-04T17:30:02.682841Z",
     "start_time": "2023-09-04T17:30:02.679637Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Note: Hamilton collects completely anonymous data about usage. This will help us improve Hamilton over time. See https://github.com/apache/hamilton#usage-analytics--data-privacy for details.\n"
     ]
    }
   ],
   "source": [
    "# DAG configuration handed to the driver through `with_config`.\n",
    "dag_config = dict(\n",
    "    test_size_fraction=0.5,\n",
    "    shuffle_train_test_split=True,\n",
    "    data_loader=\"iris\",\n",
    "    clf=\"logistic\",\n",
    "    penalty=\"l2\",\n",
    ")\n",
    "\n",
    "# Build a driver over the data-loading and model-training modules.\n",
    "dr = (\n",
    "    driver.Builder()\n",
    "    .with_adapter(base.DefaultAdapter())\n",
    "    .with_config(dag_config)\n",
    "    .with_modules(data_loaders, model_training)\n",
    "    .build()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "2f0e75e3",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-09-04T17:30:03.208108Z",
     "start_time": "2023-09-04T17:30:02.696138Z"
    }
   },
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 8.0.5 (20230430.1635)\n",
       " -->\n",
       "<!-- Pages: 1 -->\n",
       "<svg width=\"787pt\" height=\"620pt\"\n",
       " viewBox=\"0.00 0.00 787.01 620.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 616)\">\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-616 783.01,-616 783.01,4 -4,4\"/>\n",
       "<!-- classification_report_to_txt -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>classification_report_to_txt</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"746.7,-36 582.2,-36 582.2,0 746.7,0 746.7,-36\"/>\n",
       "<text text-anchor=\"middle\" x=\"664.45\" y=\"-12.95\" font-family=\"Times,serif\" font-size=\"14.00\">classification_report_to_txt</text>\n",
       "</g>\n",
       "<!-- target_names -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>target_names</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"718.45\" cy=\"-378\" rx=\"60.56\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"718.45\" y=\"-372.95\" font-family=\"Times,serif\" font-size=\"14.00\">target_names</text>\n",
       "</g>\n",
       "<!-- y_test_with_labels -->\n",
       "<g id=\"node14\" class=\"node\">\n",
       "<title>y_test_with_labels</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"627.45\" cy=\"-306\" rx=\"80.01\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"627.45\" y=\"-300.95\" font-family=\"Times,serif\" font-size=\"14.00\">y_test_with_labels</text>\n",
       "</g>\n",
       "<!-- target_names&#45;&gt;y_test_with_labels -->\n",
       "<g id=\"edge19\" class=\"edge\">\n",
       "<title>target_names&#45;&gt;y_test_with_labels</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M697.34,-360.76C685.57,-351.71 670.7,-340.27 657.7,-330.28\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"660.13,-326.96 650.07,-323.63 655.86,-332.5 660.13,-326.96\"/>\n",
       "</g>\n",
       "<!-- predicted_output_with_labels -->\n",
       "<g id=\"node20\" class=\"node\">\n",
       "<title>predicted_output_with_labels</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"516.45\" cy=\"-162\" rx=\"120.45\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"516.45\" y=\"-156.95\" font-family=\"Times,serif\" font-size=\"14.00\">predicted_output_with_labels</text>\n",
       "</g>\n",
       "<!-- target_names&#45;&gt;predicted_output_with_labels -->\n",
       "<g id=\"edge24\" class=\"edge\">\n",
       "<title>target_names&#45;&gt;predicted_output_with_labels</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M722.65,-359.66C726.24,-340.73 729.06,-310.16 716.45,-288 686.35,-235.12 623.81,-201.51 576.63,-182.62\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"578.1,-179.05 567.51,-178.7 575.57,-185.58 578.1,-179.05\"/>\n",
       "</g>\n",
       "<!-- X_train -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>X_train</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"427.45\" cy=\"-378\" rx=\"39.07\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"427.45\" y=\"-372.95\" font-family=\"Times,serif\" font-size=\"14.00\">X_train</text>\n",
       "</g>\n",
       "<!-- fit_clf -->\n",
       "<g id=\"node6\" class=\"node\">\n",
       "<title>fit_clf</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"331.45\" cy=\"-306\" rx=\"33.44\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"331.45\" y=\"-300.95\" font-family=\"Times,serif\" font-size=\"14.00\">fit_clf</text>\n",
       "</g>\n",
       "<!-- X_train&#45;&gt;fit_clf -->\n",
       "<g id=\"edge9\" class=\"edge\">\n",
       "<title>X_train&#45;&gt;fit_clf</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M407.09,-362.15C393.19,-352.02 374.61,-338.47 359.34,-327.34\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"361.94,-324.17 351.79,-321.11 357.81,-329.83 361.94,-324.17\"/>\n",
       "</g>\n",
       "<!-- train_test_split_func -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>train_test_split_func</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"460.45\" cy=\"-450\" rx=\"86.67\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"460.45\" y=\"-444.95\" font-family=\"Times,serif\" font-size=\"14.00\">train_test_split_func</text>\n",
       "</g>\n",
       "<!-- train_test_split_func&#45;&gt;X_train -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>train_test_split_func&#45;&gt;X_train</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M452.29,-431.7C448.61,-423.9 444.19,-414.51 440.09,-405.83\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"442.94,-404.66 435.51,-397.1 436.61,-407.64 442.94,-404.66\"/>\n",
       "</g>\n",
       "<!-- y_train -->\n",
       "<g id=\"node9\" class=\"node\">\n",
       "<title>y_train</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"333.45\" cy=\"-378\" rx=\"37.02\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"333.45\" y=\"-372.95\" font-family=\"Times,serif\" font-size=\"14.00\">y_train</text>\n",
       "</g>\n",
       "<!-- train_test_split_func&#45;&gt;y_train -->\n",
       "<g id=\"edge12\" class=\"edge\">\n",
       "<title>train_test_split_func&#45;&gt;y_train</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M430.99,-432.76C411.81,-422.19 386.74,-408.38 366.82,-397.4\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"368.67,-393.87 358.22,-392.1 365.29,-400 368.67,-393.87\"/>\n",
       "</g>\n",
       "<!-- y_test -->\n",
       "<g id=\"node12\" class=\"node\">\n",
       "<title>y_test</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"573.45\" cy=\"-378\" rx=\"32.93\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"573.45\" y=\"-372.95\" font-family=\"Times,serif\" font-size=\"14.00\">y_test</text>\n",
       "</g>\n",
       "<!-- train_test_split_func&#45;&gt;y_test -->\n",
       "<g id=\"edge16\" class=\"edge\">\n",
       "<title>train_test_split_func&#45;&gt;y_test</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M487.22,-432.41C503.97,-422.04 525.57,-408.66 542.97,-397.88\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"544.48,-400.44 551.14,-392.2 540.8,-394.49 544.48,-400.44\"/>\n",
       "</g>\n",
       "<!-- X_test -->\n",
       "<g id=\"node21\" class=\"node\">\n",
       "<title>X_test</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"494.45\" cy=\"-306\" rx=\"34.97\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"494.45\" y=\"-300.95\" font-family=\"Times,serif\" font-size=\"14.00\">X_test</text>\n",
       "</g>\n",
       "<!-- train_test_split_func&#45;&gt;X_test -->\n",
       "<g id=\"edge25\" class=\"edge\">\n",
       "<title>train_test_split_func&#45;&gt;X_test</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M465.6,-431.89C468.63,-421.54 472.45,-408.07 475.45,-396 480.5,-375.62 485.43,-352.42 489,-334.81\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"492.6,-335.63 491.13,-325.14 485.74,-334.26 492.6,-335.63\"/>\n",
       "</g>\n",
       "<!-- shuffle_train_test_split -->\n",
       "<g id=\"node5\" class=\"node\">\n",
       "<title>shuffle_train_test_split</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"120.45\" cy=\"-522\" rx=\"120.45\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"120.45\" y=\"-516.95\" font-family=\"Times,serif\" font-size=\"14.00\">Input: shuffle_train_test_split</text>\n",
       "</g>\n",
       "<!-- shuffle_train_test_split&#45;&gt;train_test_split_func -->\n",
       "<g id=\"edge7\" class=\"edge\">\n",
       "<title>shuffle_train_test_split&#45;&gt;train_test_split_func</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M189.25,-506.83C247.61,-494.82 331,-477.65 389.86,-465.53\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"390.47,-468.77 399.56,-463.33 389.06,-461.92 390.47,-468.77\"/>\n",
       "</g>\n",
       "<!-- clf_to_pickle -->\n",
       "<g id=\"node10\" class=\"node\">\n",
       "<title>clf_to_pickle</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"228.07,-252 140.82,-252 140.82,-216 228.07,-216 228.07,-252\"/>\n",
       "<text text-anchor=\"middle\" x=\"184.45\" y=\"-228.95\" font-family=\"Times,serif\" font-size=\"14.00\">clf_to_pickle</text>\n",
       "</g>\n",
       "<!-- fit_clf&#45;&gt;clf_to_pickle -->\n",
       "<g id=\"edge13\" class=\"edge\">\n",
       "<title>fit_clf&#45;&gt;clf_to_pickle</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M306.84,-293.28C286.13,-283.42 255.95,-269.05 230.73,-257.04\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"232.52,-253.54 221.99,-252.4 229.52,-259.86 232.52,-253.54\"/>\n",
       "</g>\n",
       "<!-- model_parameters -->\n",
       "<g id=\"node13\" class=\"node\">\n",
       "<title>model_parameters</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"326.45\" cy=\"-234\" rx=\"80.01\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"326.45\" y=\"-228.95\" font-family=\"Times,serif\" font-size=\"14.00\">model_parameters</text>\n",
       "</g>\n",
       "<!-- fit_clf&#45;&gt;model_parameters -->\n",
       "<g id=\"edge17\" class=\"edge\">\n",
       "<title>fit_clf&#45;&gt;model_parameters</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M330.21,-287.7C329.68,-280.24 329.04,-271.32 328.44,-262.97\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"331.87,-262.83 327.67,-253.1 324.89,-263.33 331.87,-262.83\"/>\n",
       "</g>\n",
       "<!-- predicted_output -->\n",
       "<g id=\"node22\" class=\"node\">\n",
       "<title>predicted_output</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"497.45\" cy=\"-234\" rx=\"73.36\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"497.45\" y=\"-228.95\" font-family=\"Times,serif\" font-size=\"14.00\">predicted_output</text>\n",
       "</g>\n",
       "<!-- fit_clf&#45;&gt;predicted_output -->\n",
       "<g id=\"edge26\" class=\"edge\">\n",
       "<title>fit_clf&#45;&gt;predicted_output</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M357.02,-294.22C382.19,-283.6 421.17,-267.17 451.68,-254.3\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"452.82,-257.2 460.67,-250.09 450.1,-250.75 452.82,-257.2\"/>\n",
       "</g>\n",
       "<!-- data -->\n",
       "<g id=\"node7\" class=\"node\">\n",
       "<title>data</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"658.45\" cy=\"-594\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"658.45\" y=\"-588.95\" font-family=\"Times,serif\" font-size=\"14.00\">data</text>\n",
       "</g>\n",
       "<!-- data&#45;&gt;target_names -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>data&#45;&gt;target_names</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M673.08,-578.64C682.26,-568.69 693.46,-554.65 699.45,-540 717.45,-495.95 719.96,-439.86 719.59,-406.85\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"723.07,-407.11 719.36,-397.19 716.07,-407.27 723.07,-407.11\"/>\n",
       "</g>\n",
       "<!-- feature_matrix -->\n",
       "<g id=\"node15\" class=\"node\">\n",
       "<title>feature_matrix</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"324.45\" cy=\"-522\" rx=\"65.68\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"324.45\" y=\"-516.95\" font-family=\"Times,serif\" font-size=\"14.00\">feature_matrix</text>\n",
       "</g>\n",
       "<!-- data&#45;&gt;feature_matrix -->\n",
       "<g id=\"edge20\" class=\"edge\">\n",
       "<title>data&#45;&gt;feature_matrix</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M632.67,-587.91C586.07,-578.67 484.69,-558.38 399.45,-540 394.49,-538.93 389.35,-537.8 384.19,-536.66\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"385.15,-533.07 374.63,-534.3 383.63,-539.9 385.15,-533.07\"/>\n",
       "</g>\n",
       "<!-- target -->\n",
       "<g id=\"node17\" class=\"node\">\n",
       "<title>target</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"658.45\" cy=\"-522\" rx=\"31.9\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"658.45\" y=\"-516.95\" font-family=\"Times,serif\" font-size=\"14.00\">target</text>\n",
       "</g>\n",
       "<!-- data&#45;&gt;target -->\n",
       "<g id=\"edge21\" class=\"edge\">\n",
       "<title>data&#45;&gt;target</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M658.45,-575.7C658.45,-568.24 658.45,-559.32 658.45,-550.97\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"661.95,-551.1 658.45,-541.1 654.95,-551.1 661.95,-551.1\"/>\n",
       "</g>\n",
       "<!-- prefit_clf -->\n",
       "<g id=\"node8\" class=\"node\">\n",
       "<title>prefit_clf</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"233.45\" cy=\"-378\" rx=\"45.21\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"233.45\" y=\"-372.95\" font-family=\"Times,serif\" font-size=\"14.00\">prefit_clf</text>\n",
       "</g>\n",
       "<!-- prefit_clf&#45;&gt;fit_clf -->\n",
       "<g id=\"edge8\" class=\"edge\">\n",
       "<title>prefit_clf&#45;&gt;fit_clf</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M254.71,-361.81C268.92,-351.66 287.77,-338.19 303.24,-327.15\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"304.78,-329.63 310.89,-320.97 300.72,-323.93 304.78,-329.63\"/>\n",
       "</g>\n",
       "<!-- y_train&#45;&gt;fit_clf -->\n",
       "<g id=\"edge10\" class=\"edge\">\n",
       "<title>y_train&#45;&gt;fit_clf</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M332.95,-359.7C332.74,-352.24 332.48,-343.32 332.24,-334.97\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"335.72,-335 331.93,-325.1 328.72,-335.2 335.72,-335\"/>\n",
       "</g>\n",
       "<!-- classification_report -->\n",
       "<g id=\"node11\" class=\"node\">\n",
       "<title>classification_report</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"727.57,-108 601.32,-108 601.32,-72 727.57,-72 727.57,-108\"/>\n",
       "<text text-anchor=\"middle\" x=\"664.45\" y=\"-84.95\" font-family=\"Times,serif\" font-size=\"14.00\">classification_report</text>\n",
       "</g>\n",
       "<!-- classification_report&#45;&gt;classification_report_to_txt -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>classification_report&#45;&gt;classification_report_to_txt</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M664.45,-71.7C664.45,-64.24 664.45,-55.32 664.45,-46.97\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"667.95,-47.1 664.45,-37.1 660.95,-47.1 667.95,-47.1\"/>\n",
       "</g>\n",
       "<!-- y_test&#45;&gt;y_test_with_labels -->\n",
       "<g id=\"edge18\" class=\"edge\">\n",
       "<title>y_test&#45;&gt;y_test_with_labels</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M585.7,-361.12C592.27,-352.59 600.53,-341.89 607.97,-332.25\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"611.15,-334.85 614.49,-324.8 605.61,-330.58 611.15,-334.85\"/>\n",
       "</g>\n",
       "<!-- model_params_to_json -->\n",
       "<g id=\"node23\" class=\"node\">\n",
       "<title>model_params_to_json</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"378.2,-180 234.7,-180 234.7,-144 378.2,-144 378.2,-180\"/>\n",
       "<text text-anchor=\"middle\" x=\"306.45\" y=\"-156.95\" font-family=\"Times,serif\" font-size=\"14.00\">model_params_to_json</text>\n",
       "</g>\n",
       "<!-- model_parameters&#45;&gt;model_params_to_json -->\n",
       "<g id=\"edge28\" class=\"edge\">\n",
       "<title>model_parameters&#45;&gt;model_params_to_json</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M321.5,-215.7C319.35,-208.15 316.77,-199.12 314.35,-190.68\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"317.45,-189.76 311.33,-181.1 310.71,-191.68 317.45,-189.76\"/>\n",
       "</g>\n",
       "<!-- y_test_with_labels&#45;&gt;classification_report -->\n",
       "<g id=\"edge15\" class=\"edge\">\n",
       "<title>y_test_with_labels&#45;&gt;classification_report</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M630.41,-287.85C636.78,-250.99 651.84,-163.92 659.61,-118.96\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"663.21,-119.68 661.47,-109.23 656.31,-118.49 663.21,-119.68\"/>\n",
       "</g>\n",
       "<!-- feature_matrix&#45;&gt;train_test_split_func -->\n",
       "<g id=\"edge4\" class=\"edge\">\n",
       "<title>feature_matrix&#45;&gt;train_test_split_func</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M354.29,-505.64C373.46,-495.77 398.61,-482.83 419.62,-472.01\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"420.96,-474.74 428.25,-467.06 417.76,-468.52 420.96,-474.74\"/>\n",
       "</g>\n",
       "<!-- penalty -->\n",
       "<g id=\"node16\" class=\"node\">\n",
       "<title>penalty</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"233.45\" cy=\"-450\" rx=\"62.61\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"233.45\" y=\"-444.95\" font-family=\"Times,serif\" font-size=\"14.00\">Input: penalty</text>\n",
       "</g>\n",
       "<!-- penalty&#45;&gt;prefit_clf -->\n",
       "<g id=\"edge11\" class=\"edge\">\n",
       "<title>penalty&#45;&gt;prefit_clf</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M233.45,-431.7C233.45,-424.24 233.45,-415.32 233.45,-406.97\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"236.95,-407.1 233.45,-397.1 229.95,-407.1 236.95,-407.1\"/>\n",
       "</g>\n",
       "<!-- target&#45;&gt;train_test_split_func -->\n",
       "<g id=\"edge5\" class=\"edge\">\n",
       "<title>target&#45;&gt;train_test_split_func</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M633.4,-510.3C628.15,-508.15 622.64,-505.95 617.45,-504 584.32,-491.56 546.78,-478.84 516.72,-468.98\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"518.01,-465.39 507.42,-465.61 515.84,-472.04 518.01,-465.39\"/>\n",
       "</g>\n",
       "<!-- predicted_output_with_labels_to_csv -->\n",
       "<g id=\"node18\" class=\"node\">\n",
       "<title>predicted_output_with_labels_to_csv</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"583.7,-108 365.2,-108 365.2,-72 583.7,-72 583.7,-108\"/>\n",
       "<text text-anchor=\"middle\" x=\"474.45\" y=\"-84.95\" font-family=\"Times,serif\" font-size=\"14.00\">predicted_output_with_labels_to_csv</text>\n",
       "</g>\n",
       "<!-- test_size_fraction -->\n",
       "<g id=\"node19\" class=\"node\">\n",
       "<title>test_size_fraction</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"508.45\" cy=\"-522\" rx=\"100.48\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"508.45\" y=\"-516.95\" font-family=\"Times,serif\" font-size=\"14.00\">Input: test_size_fraction</text>\n",
       "</g>\n",
       "<!-- test_size_fraction&#45;&gt;train_test_split_func -->\n",
       "<g id=\"edge6\" class=\"edge\">\n",
       "<title>test_size_fraction&#45;&gt;train_test_split_func</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M496.58,-503.7C491.06,-495.64 484.37,-485.89 478.26,-476.98\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"480.72,-475.37 472.17,-469.1 474.94,-479.33 480.72,-475.37\"/>\n",
       "</g>\n",
       "<!-- predicted_output_with_labels&#45;&gt;classification_report -->\n",
       "<g id=\"edge14\" class=\"edge\">\n",
       "<title>predicted_output_with_labels&#45;&gt;classification_report</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M551.52,-144.41C571.3,-135.06 596.26,-123.25 617.68,-113.12\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"619.01,-115.89 626.56,-108.45 616.02,-109.56 619.01,-115.89\"/>\n",
       "</g>\n",
       "<!-- predicted_output_with_labels&#45;&gt;predicted_output_with_labels_to_csv -->\n",
       "<g id=\"edge22\" class=\"edge\">\n",
       "<title>predicted_output_with_labels&#45;&gt;predicted_output_with_labels_to_csv</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M506.06,-143.7C501.28,-135.73 495.51,-126.1 490.2,-117.26\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"492.85,-115.88 484.71,-109.1 486.85,-119.48 492.85,-115.88\"/>\n",
       "</g>\n",
       "<!-- X_test&#45;&gt;predicted_output -->\n",
       "<g id=\"edge27\" class=\"edge\">\n",
       "<title>X_test&#45;&gt;predicted_output</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M495.19,-287.7C495.51,-280.24 495.89,-271.32 496.25,-262.97\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"499.78,-263.25 496.71,-253.1 492.79,-262.95 499.78,-263.25\"/>\n",
       "</g>\n",
       "<!-- predicted_output&#45;&gt;predicted_output_with_labels -->\n",
       "<g id=\"edge23\" class=\"edge\">\n",
       "<title>predicted_output&#45;&gt;predicted_output_with_labels</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M502.14,-215.7C504.19,-208.15 506.64,-199.12 508.93,-190.68\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"512.56,-191.67 511.8,-181.1 505.81,-189.84 512.56,-191.67\"/>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x140ea8f10>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Each materializer saves one DAG output to disk; its `id` becomes a node in the graph.\n",
    "materializers = [\n",
    "    # model parameters -> JSON file\n",
    "    to.json(\n",
    "        dependencies=[\"model_parameters\"],\n",
    "        id=\"model_params_to_json\",\n",
    "        path=\"./data/params.json\",\n",
    "    ),\n",
    "    # classification report -> .txt file\n",
    "    to.file(\n",
    "        dependencies=[\"classification_report\"],\n",
    "        id=\"classification_report_to_txt\",\n",
    "        path=\"./data/classification_report.txt\",\n",
    "    ),\n",
    "    # fitted model -> pickle file\n",
    "    to.pickle(\n",
    "        dependencies=[\"fit_clf\"],\n",
    "        id=\"clf_to_pickle\",\n",
    "        path=\"./data/clf.pkl\",\n",
    "    ),\n",
    "    # predictions with labels -> CSV file\n",
    "    to.csv(\n",
    "        dependencies=[\"predicted_output_with_labels\"],\n",
    "        id=\"predicted_output_with_labels_to_csv\",\n",
    "        path=\"./data/predicted_output_with_labels.csv\",\n",
    "    ),\n",
    "]\n",
    "\n",
    "# Render the DAG together with the materializer nodes defined above.\n",
    "dr.visualize_materialization(*materializers, additional_vars=[\"classification_report\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bce54351",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The materializers write into ./data/, so create it if it is missing (e.g. on a fresh checkout).\n",
    "os.makedirs(\"./data\", exist_ok=True)\n",
    "\n",
    "# Execute the DAG: run every materializer and also return `classification_report` in memory.\n",
    "materialization_results, additional_vars = dr.materialize(\n",
    "    *materializers,\n",
    "    additional_vars=[\"classification_report\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "282f9688",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0       1.00      1.00      1.00        94\n",
      "           1       0.91      0.93      0.92        85\n",
      "           2       0.97      0.99      0.98        96\n",
      "           3       0.99      0.97      0.98        93\n",
      "           4       0.99      0.92      0.95        88\n",
      "           5       0.95      0.95      0.95        85\n",
      "           6       0.99      0.97      0.98        97\n",
      "           7       0.97      0.97      0.97        89\n",
      "           8       0.88      0.88      0.88        82\n",
      "           9       0.91      0.97      0.94        90\n",
      "\n",
      "    accuracy                           0.96       899\n",
      "   macro avg       0.95      0.95      0.95       899\n",
      "weighted avg       0.96      0.96      0.96       899\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# The report is plain text, so print it to keep its table layout instead of showing a repr.\n",
    "print(additional_vars[\"classification_report\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "3df08628",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0       1.00      1.00      1.00        94\n",
      "           1       0.91      0.93      0.92        85\n",
      "           2       0.97      0.99      0.98        96\n",
      "           3       0.99      0.97      0.98        93\n",
      "           4       0.99      0.92      0.95        88\n",
      "           5       0.95      0.95      0.95        85\n",
      "           6       0.99      0.97      0.98        97\n",
      "           7       0.97      0.97      0.97        89\n",
      "           8       0.88      0.88      0.88        82\n",
      "           9       0.91      0.97      0.94        90\n",
      "\n",
      "    accuracy                           0.96       899\n",
      "   macro avg       0.95      0.95      0.95       899\n",
      "weighted avg       0.96      0.96      0.96       899\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Read back the file written by the `classification_report_to_txt` materializer;\n",
    "# `with` closes the file handle once the contents have been read.\n",
    "report_path = materialization_results[\"classification_report_to_txt\"][\"path\"]\n",
    "with open(report_path) as report_file:\n",
    "    print(report_file.read())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
